"""
© 2021 This work is licensed under a CC-BY-NC-SA license.
Title: *"Behavioral cloning in recurrent spiking networks: A comprehensive framework"*
**Authors:** Anonymous
"""


import gbc
import matplotlib.pyplot as plt
import numpy as np
from numpy import savetxt,loadtxt
from tqdm import trange
import test_cloner
import os.path

# ---------------------------------------------------------------------------
# I/O folders (relative to the working directory)
# ---------------------------------------------------------------------------
folder_save = "RewardsErrors"    # error / reward traces and saved models
folder_weights = "weights"       # randomly drawn weight matrices
folder_expert = "expert"         # expert action_<k>.csv / state_<k>.csv files

# ---------------------------------------------------------------------------
# Trial geometry
# ---------------------------------------------------------------------------
time_steps = 400   # number of time steps simulated per trajectory
t0 = 0             # first column taken from each expert recording

# N network units, I input rows taken from the state file,
# O action dimensions, T time steps.
N, I, O, T = 500, 24, 4, time_steps
shape = (N, I, O, T)

n_examples = 1              # number of training demonstrations
n_examples_validation = 1   # number of validation demonstrations

# ---------------------------------------------------------------------------
# Training schedule
# ---------------------------------------------------------------------------
num_iterations = 500        # recurrent-cloning iterations
epochs_rec = 1
num_iterations_out = 8      # readout-cloning iterations
epochs_out = 25

n_reps = 10                 # independent random weight draws

# ---------------------------------------------------------------------------
# Model time constants (all expressed in units of the step dt)
# ---------------------------------------------------------------------------
dt = 0.001
tau_m, tau_s, tau_ro = 4.0 * dt, 2.0 * dt, 2.0 * dt
tau_targ = 0.05 * dt
# Per-step exponential decay factors exp(-dt / tau).
beta_s, beta_ro, beta_targ = (np.exp(-dt / tau) for tau in (tau_s, tau_ro, tau_targ))

# ---------------------------------------------------------------------------
# Remaining model hyper-parameters
# ---------------------------------------------------------------------------
sigma_teach = 3.0
sigma_input = 2.0
offT = 0
dv = 1 / 5.0
alpha = 0.3
alpha_rout = 0.000375
Vo = -4
h = -4
s_inh = 20

# Simulation parameters handed to gbc.GBC (key order kept stable).
par = dict(
    tau_m=tau_m, tau_s=tau_s, tau_ro=tau_ro, beta_ro=beta_ro, beta_targ=beta_targ,
    dv=dv, alpha=alpha, Vo=Vo, h=h, s_inh=s_inh,
    N=N, T=T, dt=dt, offT=offT, alpha_rout=alpha_rout,
    sigma_input=sigma_input, sigma_teach=sigma_teach, shape=shape,
)

steps = time_steps

# Target time constants swept by the experiment: 10 values log-spaced
# between 10**-1.5 and 10**1.5 (multiplied by dt where used).
log_tau_vals = np.linspace(-1.5, 1.5, 10)
tau_vals = 10 ** log_tau_vals


# Experiment driver.
#
# The outer loop runs `n_reps` independent random weight draws. For each
# draw, the target time constant is swept over `tau_vals`. For every tau:
#   1. a fresh gbc.GBC model is built,
#   2. its recurrent weights are trained with `clone_targ`,
#   3. its readout is trained with `clone_ro`.
# Training error, validation error and the `test_cloner.test` reward are
# saved to `folder_save`.

# np.save does not create missing directories.
os.makedirs(folder_weights, exist_ok=True)
os.makedirs(folder_save, exist_ok=True)

# Indices of the expert trajectories used for training / validation.
demonstrations_nums = [1]
validation_nums = [3]

# Window [t0, t0 + time_steps) cut out of every expert recording.
win = slice(t0, t0 + time_steps)


def _load_expert(prefix, nums, count):
    """Load `<folder_expert>/<prefix><nums[k]>.csv` for k in range(count)."""
    return [loadtxt(os.path.join(folder_expert, prefix + str(nums[k]) + ".csv"))
            for k in range(count)]


# The expert data depends neither on the repetition nor on tau: read it once.
a_coll_load = _load_expert("action_", demonstrations_nums, n_examples)
s_coll_load = _load_expert("state_", demonstrations_nums, n_examples)
a_coll_val_load = _load_expert("action_", validation_nums, n_examples_validation)
s_coll_val_load = _load_expert("state_", validation_nums, n_examples_validation)

# Training windows. Actions are copied so that blanking the first `offT`
# steps does not write through a slice view into the loaded arrays.
s_train = [s[0:I, win] for s in s_coll_load]
a_train = []
for a in a_coll_load:
    a_win = a[:, win].copy()
    a_win[:, 0:offT] = 0
    a_train.append(a_win)

rank = 0  # forwarded unchanged to clone_targ

for n_rep in range(n_reps):

    # Fresh random weights for this repetition.
    Jteach = np.random.normal(0, 1, size=(N, O))
    Jin = np.random.normal(0, 1, size=(N, I))
    B = np.random.normal(0, 1, size=(N, N)) * .1
    J = np.random.normal(0, 1, size=(N, N)) * .1

    # Persist the draw. NOTE(review): J is saved but never used below,
    # since the model's recurrent weights are zeroed before cloning.
    np.save(os.path.join(folder_weights, "B.npy"), B)
    np.save(os.path.join(folder_weights, "Jin.npy"), Jin)
    np.save(os.path.join(folder_weights, "Jteach.npy"), Jteach)
    np.save(os.path.join(folder_weights, "J.npy"), J)

    for ndx_tau, tau_val in enumerate(tau_vals):
        tau_targ = tau_val * dt
        beta_targ = np.exp(-dt / tau_targ)

        # Fresh model per tau. It is bound to `model` so the module name
        # `gbc` is not shadowed by the instance.
        model = gbc.GBC(par)
        model.par["beta_targ"] = beta_targ

        # Recurrent weights start from zero, with no self-connections.
        model.J = model.J * 0.
        np.fill_diagonal(model.J, 0.)

        # Use this repetition's weights directly. They are bit-identical to
        # the .npy copies saved above.
        model.Jin = Jin * sigma_input
        model.Jteach = Jteach * sigma_teach

        # implement() returns (targets, inputs). Call it once per trajectory
        # so that each target/input pair comes from the same call.
        implemented = [model.implement(s[0:I, win], a[:, win], time_steps)
                       for s, a in zip(s_coll_load, a_coll_load)]
        itargets = [res[0] for res in implemented]
        inputs = [res[1] for res in implemented]
        inputs_val = [model.implement(s[0:I, win], a[:, win], time_steps)[1]
                      for s, a in zip(s_coll_val_load, a_coll_val_load)]

        model.Jout = model.Jout * 0.

        # --- Phase 1: clone the recurrent weights --------------------------
        DS = None  # stays None when num_iterations == 0
        pbar = trange(num_iterations, leave=False, desc='Cloning')
        for _ in pbar:
            internal_error = 0
            for n_batch in range(n_examples):
                DS = model.clone_targ(s_train[n_batch], a_train[n_batch], itargets[n_batch],
                                      0., B, epochs=epochs_rec, rank=rank, clumped=True)
                S_gen, action = model.compute(inputs[n_batch])
                internal_error += np.mean(np.abs(S_gen - itargets[n_batch]))
            pbar.set_postfix(internal_error=internal_error)

        # NOTE(review): files are tagged "_not_clumped" although clone_targ is
        # called with clumped=True. The names are kept so that existing
        # loaders still find them.
        suffix = ("_st" + str(sigma_teach) + "_" + str(n_rep)
                  + "_not_clumped_tau" + str(tau_val) + ".npy")

        if DS is not None:
            np.save(os.path.join(folder_save, "DS" + suffix), DS)

        # --- Phase 2: collect post-training activity for the readout ---------
        a_aggr = np.zeros((O, 0), dtype=float)
        S_aggr = np.zeros((N, 0), dtype=float)
        for n_batch in range(n_examples):
            S_gen, action = model.compute(inputs[n_batch])
            a_aggr = np.concatenate((a_aggr, a_train[n_batch]), axis=1)
            S_aggr = np.concatenate((S_aggr, S_gen), axis=1)

        # --- Phase 3: clone the readout, evaluating after each iteration -----
        avg_reward_collection = []
        validation_error_collection = []
        training_error_collection = []

        for _ in trange(num_iterations_out, leave=False, desc='Cloning'):

            a_out, external_error = model.clone_ro(a_aggr, S_aggr, epochs=epochs_out)
            avg_reward = test_cloner.test(20, model)
            avg_reward_collection.append(avg_reward)

            external_error_val = 0
            for n_batch in range(n_examples_validation):
                S_gen_val, action_val = model.compute(inputs_val[n_batch][:, 0:time_steps])
                # Action at step t+1 is compared with the expert action at
                # step t. The expert slice is taken from the same t0-offset
                # window that produced the inputs.
                external_error_val += np.std(action_val[:, 1:time_steps - 1]
                                             - a_coll_val_load[n_batch][:, t0:t0 + time_steps - 2])

            validation_error_collection.append(external_error_val)
            training_error_collection.append(external_error)

            print("- ext error: " + str("{: 0.4f}".format(external_error)))
            print("- ext error val: " + str("{: 0.4f}".format(external_error_val)))
            print("- avg reward: " + str("{: 0.4f}".format(avg_reward)))

            # Checkpoint the curves after every readout iteration.
            np.save(os.path.join(folder_save, "training_error_collection" + suffix), training_error_collection)
            np.save(os.path.join(folder_save, "validation_error_collection" + suffix), validation_error_collection)
            np.save(os.path.join(folder_save, "avg_reward" + suffix), avg_reward_collection)

        model.save(os.path.join(folder_save, "model_N" + suffix))
